/*
Copyright 2020 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package cluster

import (
	"context"
	"sort"
	"strings"
	"time"

	"github.com/blang/semver/v4"
	"github.com/pkg/errors"
	appsv1 "k8s.io/api/apps/v1"
	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/apimachinery/pkg/util/version"
	"k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/klog/v2"
	"k8s.io/utils/ptr"
	"sigs.k8s.io/controller-runtime/pkg/client"

	clusterv1 "sigs.k8s.io/cluster-api/api/core/v1beta2"
	clusterctlv1 "sigs.k8s.io/cluster-api/cmd/clusterctl/api/v1alpha3"
	"sigs.k8s.io/cluster-api/cmd/clusterctl/client/config"
	"sigs.k8s.io/cluster-api/cmd/clusterctl/client/repository"
	logf "sigs.k8s.io/cluster-api/cmd/clusterctl/log"
)

// ProviderUpgrader defines methods for supporting provider upgrade.
type ProviderUpgrader interface {
	// Plan returns a set of suggested Upgrade plans for the management cluster.
	// The implementation in this file computes one plan for each contract version
	// the core provider can be upgraded to.
	Plan(ctx context.Context) ([]UpgradePlan, error)

	// ApplyPlan executes an upgrade following an UpgradePlan generated by clusterctl.
	// The implementation in this file only accepts the contract version supported by
	// the running clusterctl release and returns an error for any other contract.
	ApplyPlan(ctx context.Context, opts UpgradeOptions, contract string) error

	// ApplyCustomPlan executes an upgrade using the UpgradeItems provided by the user.
	ApplyCustomPlan(ctx context.Context, opts UpgradeOptions, providersToUpgrade ...UpgradeItem) error
}

// UpgradePlan defines a list of possible upgrade targets for a management cluster.
//
// Contract is the contract version the plan targets. Providers holds one UpgradeItem
// for each provider considered by the plan; an item whose NextVersion is empty has no
// upgrade target (see isPartialUpgrade).
type UpgradePlan struct {
	Contract  string
	Providers []UpgradeItem
}

// UpgradeOptions defines the options used to upgrade installation.
//
// WaitProviders and WaitProviderTimeout are forwarded to the InstallOptions used when
// waiting for the upgraded providers to become ready. EnableCRDStorageVersionMigration,
// when true, makes the upgrade migrate custom resources to the latest CRD storage version
// before the providers are scaled down or deleted.
type UpgradeOptions struct {
	WaitProviders                    bool
	WaitProviderTimeout              time.Duration
	EnableCRDStorageVersionMigration bool
}

// isPartialUpgrade reports whether the plan contains at least one upgradeItem
// that does not have a target version.
func (u *UpgradePlan) isPartialUpgrade() bool {
	for idx := range u.Providers {
		if u.Providers[idx].NextVersion != "" {
			continue
		}
		// Found a provider with nothing to upgrade to: the plan is partial.
		return true
	}
	return false
}

// UpgradeItem defines a possible upgrade target for a provider in the management cluster.
//
// The embedded Provider identifies the provider being upgraded. NextVersion is the version
// to upgrade to; an empty value means there is no upgrade target for this provider.
type UpgradeItem struct {
	clusterctlv1.Provider
	NextVersion string
}

// UpgradeRef returns a string identifying the upgrade item; the string is the
// instance name of the provider the item refers to.
func (u *UpgradeItem) UpgradeRef() string {
	return u.Provider.InstanceName()
}

// providerUpgrader implements ProviderUpgrader.
//
// currentContractVersion is the only contract version an upgrade can target: ApplyPlan and
// createCustomPlan return an error for any other contract. getCompatibleContractVersions
// returns the set of contract versions accepted as compatible with a given contract.
type providerUpgrader struct {
	configClient                  config.Client
	proxy                         Proxy
	repositoryClientFactory       RepositoryClientFactory
	providerInventory             InventoryClient
	providerComponents            ComponentsClient
	currentContractVersion        string
	getCompatibleContractVersions func(string) sets.Set[string]
}

// Compile-time check that providerUpgrader satisfies the ProviderUpgrader interface.
var _ ProviderUpgrader = &providerUpgrader{}

// Plan returns the upgrade plans available for the management cluster, one for each contract
// version considered for the core provider (the current contract and the new ones available, if any).
//
// It returns an error if the provider inventory cannot be listed, if the management cluster does
// not have exactly one core provider, if the upgrade info of a provider cannot be computed, or if
// no contract version can be identified for the core provider.
func (u *providerUpgrader) Plan(ctx context.Context) ([]UpgradePlan, error) {
	log := logf.Log
	log.Info("Checking new release availability...")

	providerList, err := u.providerInventory.List(ctx)
	if err != nil {
		return nil, err
	}

	// The core provider is driving all the plan logic for entire management cluster, because all the providers
	// are expected to support the same contract version or compatible ones.
	// e.g if the core provider supports v1alpha4, all the providers in the same management cluster should support v1alpha4 as well;
	// all the providers in the management cluster can upgrade to the latest release supporting v1alpha4, or if available,
	// all the providers can upgrade to the latest release supporting v1alpha5 (not supported in current clusterctl release,
	// but upgrade plan should report these options)

	// Gets the upgrade info for the core provider.
	coreProviders := providerList.FilterCore()
	if len(coreProviders) != 1 {
		return nil, errors.Errorf("invalid management cluster: there must be one core provider, found %d", len(coreProviders))
	}
	coreProvider := coreProviders[0]

	coreUpgradeInfo, err := u.getUpgradeInfo(ctx, coreProvider)
	if err != nil {
		return nil, err
	}

	// Identifies the contract version that we should consider for the management cluster update (Nb. the core provider is driving the entire management cluster).
	// This includes the current contract and the new ones available, if any.
	contractsForUpgrade := coreUpgradeInfo.getContractsForUpgrade()
	if len(contractsForUpgrade) == 0 {
		// NOTE: err is always nil here, so wrapping it would yield a nil error and silently
		// report success; a new error must be created instead.
		return nil, errors.Errorf("invalid metadata: unable to find the contract version implemented by the %s provider", coreProvider.InstanceName())
	}

	// Creates an UpgradePlan for each contract version considered for upgrades; each upgrade plans contains
	// an UpgradeItem for each provider defining the next available version with the target contract versions or a compatible contract version, if available.
	// e.g. v1alpha4, cluster-api --> v0.4.1, kubeadm bootstrap --> v0.4.1, aws --> v0.X.2
	// e.g. v1alpha4, cluster-api --> v0.5.1, kubeadm bootstrap --> v0.5.1, aws --> v0.Y.4 (not supported in current clusterctl release, but upgrade plan should report these options).
	ret := make([]UpgradePlan, 0, len(contractsForUpgrade))
	for _, contract := range contractsForUpgrade {
		upgradePlan, err := u.getUpgradePlan(ctx, providerList.Items, contract)
		if err != nil {
			return nil, err
		}

		// If the upgrade plan is partial (at least one upgradeItem in the plan does not have a target version) and
		// the upgrade plan requires a change of the contract for this management cluster, then drop it
		// (all the provider in a management cluster are required to change contract at the same time).
		if upgradePlan.isPartialUpgrade() && coreUpgradeInfo.currentContract != contract {
			continue
		}

		ret = append(ret, *upgradePlan)
	}

	return ret, nil
}

// ApplyPlan computes the upgrade plan for the given contract and executes it.
// Only the contract version supported by the current clusterctl release is accepted.
func (u *providerUpgrader) ApplyPlan(ctx context.Context, opts UpgradeOptions, contract string) error {
	if u.currentContractVersion != contract {
		return errors.Errorf("current version of clusterctl could only upgrade to %s contract, requested %s", u.currentContractVersion, contract)
	}

	logf.Log.Info("Performing upgrade...")

	// Build the upgrade plan for the selected contract from the providers in the inventory.
	inventory, err := u.providerInventory.List(ctx)
	if err != nil {
		return err
	}
	plan, err := u.getUpgradePlan(ctx, inventory.Items, contract)
	if err != nil {
		return err
	}

	// Keep only the providers that have a version to upgrade to, preserving their order.
	toUpgrade := plan.Providers[:0]
	for _, item := range plan.Providers {
		if item.NextVersion != "" {
			toUpgrade = append(toUpgrade, item)
		}
	}
	plan.Providers = toUpgrade

	// Do the upgrade
	return u.doUpgrade(ctx, plan, opts)
}

// ApplyCustomPlan builds an upgrade plan out of the user provided upgrade items and executes it.
func (u *providerUpgrader) ApplyCustomPlan(ctx context.Context, opts UpgradeOptions, upgradeItems ...UpgradeItem) error {
	logf.Log.Info("Performing upgrade...")

	// The custom plan is validated while it is built: all the providers in the management cluster
	// must be consistent with the contract version of the core provider (or compatible ones).
	customPlan, planErr := u.createCustomPlan(ctx, upgradeItems)
	if planErr != nil {
		return planErr
	}

	return u.doUpgrade(ctx, customPlan, opts)
}

// getUpgradePlan builds the upgrade plan for the given providers and target contract.
// The plan contains one UpgradeItem per provider; NextVersion is derived from the latest
// next version reported by the provider upgrade info for the compatible contracts.
// NB. this function is used both for upgrade plan and upgrade apply.
func (u *providerUpgrader) getUpgradePlan(ctx context.Context, providers []clusterctlv1.Provider, contract string) (*UpgradePlan, error) {
	// Contract versions accepted as upgrade target for the requested contract.
	targetContracts := u.getCompatibleContractVersions(contract)

	plan := &UpgradePlan{
		Contract:  contract,
		Providers: make([]UpgradeItem, 0, len(providers)),
	}
	for idx := range providers {
		// Gets the upgrade info for the provider.
		info, err := u.getUpgradeInfo(ctx, providers[idx])
		if err != nil {
			return nil, err
		}

		// Next available version implementing the target contract or a compatible one, if any.
		latest := info.getLatestNextVersion(targetContracts)

		plan.Providers = append(plan.Providers, UpgradeItem{
			Provider:    providers[idx],
			NextVersion: versionTag(latest),
		})
	}

	return plan, nil
}

// createCustomPlan creates a custom upgrade plan from a set of upgrade items, taking care of ensuring all the providers
// in a management cluster are consistent with the contract version of the core provider (or compatible ones).
//
// It returns an error if:
//   - the management cluster does not have exactly one core provider;
//   - the contract targeted by the core provider differs from the one supported by this clusterctl release;
//   - an upgrade item does not match any provider in the management cluster;
//   - the target version of an upgrade item implements a contract that is not compatible with the target contract;
//   - a provider not included in the upgrade implements a contract that is not compatible with the target contract.
func (u *providerUpgrader) createCustomPlan(ctx context.Context, upgradeItems []UpgradeItem) (*UpgradePlan, error) {
	// Gets the contract version of the core provider.
	// This is required to ensure all the providers in a management cluster are consistent with the contract supported by the core provider.
	// e.g. if the core provider is v1beta1, all the providers should be v1beta1 as well.

	// The target contract is derived from the current version of the core provider, or, if the core provider is included in the upgrade list,
	// from its target version.
	providerList, err := u.providerInventory.List(ctx)
	if err != nil {
		return nil, err
	}
	coreProviders := providerList.FilterCore()
	if len(coreProviders) != 1 {
		return nil, errors.Errorf("invalid management cluster: there must be one core provider, found %d", len(coreProviders))
	}
	coreProvider := coreProviders[0]

	targetCoreProviderVersion := coreProvider.Version
	for _, providerToUpgrade := range upgradeItems {
		if providerToUpgrade.InstanceName() == coreProvider.InstanceName() {
			targetCoreProviderVersion = providerToUpgrade.NextVersion
			break
		}
	}

	targetContract, err := u.getProviderContractByVersion(ctx, coreProvider, targetCoreProviderVersion)
	if err != nil {
		return nil, err
	}

	if targetContract != u.currentContractVersion {
		return nil, errors.Errorf("current version of clusterctl could only upgrade the core provider to %s contract version, requested %s", u.currentContractVersion, targetContract)
	}
	compatibleContracts := u.getCompatibleContractVersions(targetContract)

	// Render the compatible contracts once, sorted, so the error messages below are
	// deterministic (UnsortedList gives no ordering guarantee).
	compatibleContractNames := compatibleContracts.UnsortedList()
	sort.Strings(compatibleContractNames)
	compatibleContractsMsg := strings.Join(compatibleContractNames, ", ")

	// Builds the custom upgrade plan, by adding all the upgrade items after checking consistency with the targetContract.
	upgradeInstanceNames := sets.Set[string]{}
	upgradePlan := &UpgradePlan{
		Contract: targetContract,
	}

	for _, upgradeItem := range upgradeItems {
		// Match the upgrade item with the corresponding provider in the management cluster
		var provider *clusterctlv1.Provider
		for i := range providerList.Items {
			if providerList.Items[i].InstanceName() == upgradeItem.InstanceName() {
				provider = &providerList.Items[i]
				break
			}
		}
		if provider == nil {
			return nil, errors.Errorf("unable to perform upgrade: the provider %s is not part of the management cluster", upgradeItem.InstanceName())
		}

		// Default the current version of the item to the one recorded in the inventory.
		if upgradeItem.Version == "" {
			upgradeItem.Version = provider.Version
		}

		// Retrieves the contract that is supported by the target version of the provider.
		contract, err := u.getProviderContractByVersion(ctx, *provider, upgradeItem.NextVersion)
		if err != nil {
			return nil, err
		}

		if !compatibleContracts.Has(contract) {
			return nil, errors.Errorf("unable to perform upgrade: the target version for the provider %s implements the %s contract version, while the core provider supports %s contract versions", upgradeItem.InstanceName(), contract, compatibleContractsMsg)
		}

		upgradePlan.Providers = append(upgradePlan.Providers, upgradeItem)
		upgradeInstanceNames.Insert(upgradeItem.InstanceName())
	}

	// Before performing the upgrades, checks if other providers in the management cluster are lagging behind the target contract.
	for _, provider := range providerList.Items {
		// skip providers already included in the upgrade plan
		if upgradeInstanceNames.Has(provider.InstanceName()) {
			continue
		}

		// Retrieves the contract that is supported by the current version of the provider.
		contract, err := u.getProviderContractByVersion(ctx, provider, provider.Version)
		if err != nil {
			return nil, err
		}

		if !compatibleContracts.Has(contract) {
			return nil, errors.Errorf("unable to perform upgrade: the provider %s implements the %s contract version, while the core provider is getting updated to a version that supports %s contract versions. Please include the %[1]s provider in the upgrade", provider.InstanceName(), contract, compatibleContractsMsg)
		}
	}
	return upgradePlan, nil
}

// getProviderContractByVersion returns the contract of the release series the given target
// version belongs to, i.e. the contract the provider will support once updated to that version.
// It fails if the target version is not a valid semantic version or matches no release series.
func (u *providerUpgrader) getProviderContractByVersion(ctx context.Context, provider clusterctlv1.Provider, targetVersion string) (string, error) {
	parsedTarget, parseErr := version.ParseSemantic(targetVersion)
	if parseErr != nil {
		return "", errors.Wrapf(parseErr, "failed to parse target version for the %s provider", provider.InstanceName())
	}

	// Gets the upgrade info, which carries the metadata, for the given provider.
	info, err := u.getUpgradeInfo(ctx, provider)
	if err != nil {
		return "", err
	}

	// The release series matching the target version defines the contract.
	if series := info.metadata.GetReleaseSeriesForVersion(parsedTarget); series != nil {
		return series.Contract, nil
	}
	return "", errors.Errorf("invalid target version: version %s for the provider %s does not match any release series", targetVersion, provider.InstanceName())
}

// getUpgradeComponents fetches from the provider repository the components for the
// NextVersion of the given upgrade item, targeting the namespace of the provider.
func (u *providerUpgrader) getUpgradeComponents(ctx context.Context, provider UpgradeItem) (repository.Components, error) {
	// Resolve the provider configuration, which is needed to build the repository client.
	providerConfig, err := u.configClient.Providers().Get(provider.ProviderName, provider.GetProviderType())
	if err != nil {
		return nil, err
	}

	repoClient, err := u.repositoryClientFactory(ctx, providerConfig, u.configClient)
	if err != nil {
		return nil, err
	}

	upgradeComponents, err := repoClient.Components().Get(ctx, repository.ComponentsOptions{
		Version:         provider.NextVersion,
		TargetNamespace: provider.Namespace,
	})
	if err != nil {
		return nil, err
	}
	return upgradeComponents, nil
}

// doUpgrade executes the given upgrade plan.
//
// The upgrade goes through the following phases, stopping at the first error:
//  1. check that there are no multiple instances of the same provider;
//  2. block unsupported skip upgrades (see checkSkipUpgrades);
//  3. sort the providers by provider type order;
//  4. if opts.EnableCRDStorageVersionMigration is set, migrate CRs to the latest CRD storage version;
//  5. scale down all the providers to upgrade;
//  6. delete each old provider (preserving CRDs, namespace and inventory) and install the new version;
//  7. wait for the installed providers according to opts.WaitProviders / opts.WaitProviderTimeout.
//
// Upgrade items without a NextVersion are skipped in every phase.
// NOTE: the Providers slice of the plan is sorted in place.
func (u *providerUpgrader) doUpgrade(ctx context.Context, upgradePlan *UpgradePlan, opts UpgradeOptions) error {
	// Check for multiple instances of the same provider (not supported).
	if err := u.providerInventory.CheckSingleProviderInstance(ctx); err != nil {
		return err
	}

	// Block unsupported skip upgrades before changing anything in the management cluster.
	if err := u.checkSkipUpgrades(upgradePlan.Providers); err != nil {
		return err
	}

	// Ensure Providers are updated in the following order: Core, Bootstrap, ControlPlane, Infrastructure.
	// The sort is stable so providers of the same type keep the order they have in the plan.
	providers := upgradePlan.Providers
	sort.SliceStable(providers, func(a, b int) bool {
		return providers[a].GetProviderType().Order() < providers[b].GetProviderType().Order()
	})

	if opts.EnableCRDStorageVersionMigration {
		// Migrate CRs to latest CRD storage version, if necessary.
		// Note: We have to do this before the providers are scaled down or deleted
		// so conversion webhooks still work.
		if err := u.migrateCRDStorageVersions(ctx, providers); err != nil {
			return err
		}
	}

	// Scale down all providers.
	// This is done to ensure all Pods of all "old" provider Deployments have been deleted.
	// Otherwise it can happen that a provider Pod survives the upgrade because we create
	// a new Deployment with the same selector directly after `Delete`.
	// This can lead to a failed upgrade because:
	// * new provider Pods fail to startup because they try to list resources.
	// * list resources fails, because the API server hits the old provider Pod when trying to
	//   call the conversion webhook for those resources.
	for _, upgradeItem := range providers {
		// If there is not a specified next version, skip it (we are already up-to-date).
		if upgradeItem.NextVersion == "" {
			continue
		}

		// Scale down provider.
		if err := u.scaleDownProvider(ctx, upgradeItem.Provider); err != nil {
			return err
		}
	}

	installQueue := []repository.Components{}

	// Delete old providers and deploy new ones if necessary, i.e. there is a NextVersion.
	for _, upgradeItem := range providers {
		// If there is not a specified next version, skip it (we are already up-to-date).
		if upgradeItem.NextVersion == "" {
			continue
		}

		// Gets the provider components for the target version.
		components, err := u.getUpgradeComponents(ctx, upgradeItem)
		if err != nil {
			return err
		}

		installQueue = append(installQueue, components)

		// Delete the provider, preserving CRD, namespace and the inventory.
		if err := u.providerComponents.Delete(ctx, DeleteOptions{
			Provider:         upgradeItem.Provider,
			IncludeNamespace: false,
			IncludeCRDs:      false,
			SkipInventory:    true,
		}); err != nil {
			return err
		}

		// Install the new version of the provider components.
		if err := installComponentsAndUpdateInventory(ctx, components, u.providerComponents, u.providerInventory); err != nil {
			return err
		}
	}

	installOpts := InstallOptions{
		WaitProviders:       opts.WaitProviders,
		WaitProviderTimeout: opts.WaitProviderTimeout,
	}
	return waitForProvidersReady(ctx, installOpts, installQueue, u.proxy)
}

// checkSkipUpgrades blocks unsupported skip upgrades for Core, Kubeadm Bootstrap, Kubeadm ControlPlane.
// The rule is enforced only when the current version of the provider is v1.10.0 or greater; in that
// case the minor of the target version cannot exceed the current minor by more than 3.
// NOTE: in future we might consider extending the clusterctl contract to support enforcing of skip upgrade
// rules for out of tree providers.
func (u *providerUpgrader) checkSkipUpgrades(providers []UpgradeItem) error {
	minVersionSkew := semver.MustParse("1.10.0")
	for _, upgradeItem := range providers {
		// If there is not a specified next version there is nothing to validate: every
		// other phase of the upgrade skips the item as well (we are already up-to-date).
		if upgradeItem.NextVersion == "" {
			continue
		}

		isCore := upgradeItem.Type == string(clusterctlv1.CoreProviderType)
		isKubeadmBootstrap := upgradeItem.Type == string(clusterctlv1.BootstrapProviderType) &&
			upgradeItem.ProviderName == config.KubeadmBootstrapProviderName
		isKubeadmControlPlane := upgradeItem.Type == string(clusterctlv1.ControlPlaneProviderType) &&
			upgradeItem.ProviderName == config.KubeadmControlPlaneProviderName
		if !isCore && !isKubeadmBootstrap && !isKubeadmControlPlane {
			continue
		}

		currentVersion, err := semver.ParseTolerant(upgradeItem.Version)
		if err != nil {
			return errors.Wrapf(err, "failed to parse current version for %s provider", upgradeItem.InstanceName())
		}

		// Versions older than the minimum version skew are not subject to the rule.
		if currentVersion.LT(minVersionSkew) {
			continue
		}

		nextVersion, err := semver.ParseTolerant(upgradeItem.NextVersion)
		if err != nil {
			return errors.Wrapf(err, "failed to parse next version for %s provider", upgradeItem.InstanceName())
		}

		if nextVersion.Minor > currentVersion.Minor+3 {
			return errors.Errorf("upgrade for %s provider can't skip more than 3 versions", upgradeItem.InstanceName())
		}
	}
	return nil
}

// migrateCRDStorageVersions runs the CRD migrator against the target version components
// of every provider that has a NextVersion.
func (u *providerUpgrader) migrateCRDStorageVersions(ctx context.Context, providers []UpgradeItem) error {
	for _, upgradeItem := range providers {
		// If there is not a specified next version, skip it (we are already up-to-date).
		if upgradeItem.NextVersion == "" {
			continue
		}

		// Gets the provider components for the target version.
		components, err := u.getUpgradeComponents(ctx, upgradeItem)
		if err != nil {
			return err
		}

		c, err := u.proxy.NewClient(ctx)
		if err != nil {
			return err
		}

		if err := NewCRDMigrator(c).Run(ctx, components.Objs()); err != nil {
			return err
		}
	}
	return nil
}

// scaleDownProvider scales down all the Deployments belonging to the given provider.
// Deployments are looked up in the provider namespace by the clusterctl label and the
// provider name label; each of them is scaled down via scaleDownDeployment.
func (u *providerUpgrader) scaleDownProvider(ctx context.Context, provider clusterctlv1.Provider) error {
	log := logf.Log
	// The version is logged by value: passing a *string would make the output depend on
	// how the log sink renders pointers.
	log.Info("Scaling down", "Provider", klog.KObj(&provider), "providerVersion", provider.Version)

	cs, err := u.proxy.NewClient(ctx)
	if err != nil {
		return err
	}

	// Fetch all Deployments belonging to a provider.
	deploymentList := &appsv1.DeploymentList{}
	if err := cs.List(ctx,
		deploymentList,
		client.InNamespace(provider.Namespace),
		client.MatchingLabels{
			clusterctlv1.ClusterctlLabel: "",
			clusterv1.ProviderNameLabel:  provider.ManifestLabel(),
		}); err != nil {
		return errors.Wrapf(err, "failed to list Deployments for provider %s", provider.Name)
	}

	// Scale down provider Deployments.
	for i := range deploymentList.Items {
		deployment := &deploymentList.Items[i]
		log.V(5).Info("Scaling down", "Deployment", klog.KObj(deployment))
		if err := scaleDownDeployment(ctx, cs, *deployment); err != nil {
			return err
		}
	}

	return nil
}

// scaleDownDeployment sets the replicas of a Deployment to 0 and then waits until the
// Deployment status reports that no replicas are left.
func scaleDownDeployment(ctx context.Context, c client.Client, deploy appsv1.Deployment) error {
	key := client.ObjectKeyFromObject(&deploy)
	name := deploy.GetName()

	// fetch reads the current state of the Deployment through the client.
	fetch := func(ctx context.Context) (*appsv1.Deployment, error) {
		current := &appsv1.Deployment{}
		if err := c.Get(ctx, key, current); err != nil {
			return nil, errors.Wrapf(err, "failed to get Deployment/%s", name)
		}
		return current, nil
	}

	// scaleToZero updates the Deployment spec, unless it already asks for zero replicas.
	scaleToZero := func(ctx context.Context) error {
		current, err := fetch(ctx)
		if err != nil {
			return err
		}

		if replicas := current.Spec.Replicas; replicas != nil && *replicas == 0 {
			// Deployment already scaled down, nothing to update.
			return nil
		}

		current.Spec.Replicas = ptr.To(int32(0))
		if err := c.Update(ctx, current); err != nil {
			return errors.Wrapf(err, "failed to update Deployment/%s", name)
		}
		return nil
	}
	if err := retryWithExponentialBackoff(ctx, newWriteBackoff(), scaleToZero); err != nil {
		return errors.Wrapf(err, "failed to scale down Deployment")
	}

	// waitForZero fails until the Deployment status reports zero replicas.
	waitForZero := func(ctx context.Context) error {
		current, err := fetch(ctx)
		if err != nil {
			return err
		}

		if current.Status.Replicas != 0 {
			return errors.Errorf("Deployment still has %d replicas", current.Status.Replicas)
		}
		return nil
	}
	// Factor 1 keeps the base interval constant (1s, plus jitter) across the 60 steps.
	scaleToZeroBackoff := wait.Backoff{
		Duration: 1 * time.Second,
		Factor:   1,
		Steps:    60,
		Jitter:   0.4,
	}
	if err := retryWithExponentialBackoff(ctx, scaleToZeroBackoff, waitForZero); err != nil {
		return errors.Wrapf(err, "failed to wait until Deployment is scaled down")
	}

	return nil
}

// newProviderUpgrader returns a providerUpgrader wired with the given clients, the contract
// version supported by the current clusterctl release and the function used to resolve
// the contract versions compatible with a given contract.
func newProviderUpgrader(configClient config.Client, proxy Proxy, repositoryClientFactory RepositoryClientFactory, providerInventory InventoryClient, providerComponents ComponentsClient, currentContractVersion string, getCompatibleContractVersions func(string) sets.Set[string]) *providerUpgrader {
	upgrader := new(providerUpgrader)

	// Clients used to talk to the configuration, the management cluster and the provider repositories.
	upgrader.configClient = configClient
	upgrader.proxy = proxy
	upgrader.repositoryClientFactory = repositoryClientFactory
	upgrader.providerInventory = providerInventory
	upgrader.providerComponents = providerComponents

	// Contract settings driving which upgrades are allowed.
	upgrader.currentContractVersion = currentContractVersion
	upgrader.getCompatibleContractVersions = getCompatibleContractVersions

	return upgrader
}
